# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# Public API of this module: only the ViT model is exported by `from ... import *`.
__all__ = ["ViT"]

# helpers


def pair(t):
    """Return ``t`` as-is when it is a tuple, otherwise duplicate it into ``(t, t)``.

    Lets image and patch sizes be given either as a single value or as an
    explicit ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)


# classes


class PreNorm(nn.Module):
    """Wrap a module so its input is layer-normalised first.

    Args:
        dim: size of the last input dimension, normalised by ``nn.LayerNorm``.
        fn: module applied to the normalised input; any keyword arguments
            given to ``forward`` are passed through to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class FeedForward(nn.Module):
    """Position-wise two-layer MLP.

    Computes ``Linear(dim -> hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim -> dim) -> Dropout`` on the last dimension.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # Kept as one Sequential so parameter names stay `net.0.*` / `net.3.*`.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        result = self.net(x)
        return result


class Attention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Args:
        dim: feature size of the input and output tokens.
        heads: number of attention heads.
        dim_head: feature size of each head; this implementation requires
            ``heads * dim_head == dim``.
        qkv_bias: whether the fused QKV projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self, dim, heads=12, dim_head=64, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
    ):
        super().__init__()
        width = heads * dim_head
        assert dim == width
        self.heads = heads
        # Scaled dot-product factor 1 / sqrt(dim_head).
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Linear(dim, width * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.proj = nn.Linear(width, dim)

    def forward(self, x):
        # One rearrange splits the fused projection into q/k/v and heads:
        # (b, n, 3 * h * d) -> (3, b, h, n, d)
        stacked = rearrange(
            self.to_qkv(x),
            "b n (three h d) -> three b h n d",
            three=3,
            h=self.heads,
        )
        query, key, value = stacked
        scores = (query * self.scale) @ key.transpose(-1, -2)
        weights = self.attn_drop(scores.softmax(dim=-1))
        # Merge heads back: (b, h, n, d) -> (b, n, h * d)
        context = rearrange(weights @ value, "b h n d -> b n (h d)")
        return self.proj_drop(self.proj(context))


class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder blocks with residual connections.

    Every block applies ``x = x + Attention(LayerNorm(x))`` and then
    ``x = x + FeedForward(LayerNorm(x))``. ``proj_drop`` serves both as the
    attention output dropout and as the feed-forward dropout.

    Args:
        dim: token feature size.
        depth: number of encoder blocks.
        heads: attention heads per block.
        dim_head: feature size per attention head.
        mlp_dim: hidden width of each feed-forward layer.
        qkv_bias: whether the attention QKV projections have a bias.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after attention projections and in the MLP.
    """

    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()

        def build_block():
            # One encoder block: [pre-norm attention, pre-norm MLP].
            attention_branch = PreNorm(
                dim,
                Attention(
                    dim,
                    heads=heads,
                    dim_head=dim_head,
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                ),
            )
            mlp_branch = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=proj_drop))
            return nn.ModuleList([attention_branch, mlp_branch])

        self.layers = nn.ModuleList([build_block() for _ in range(depth)])

    def forward(self, x):
        for attention_branch, mlp_branch in self.layers:
            x = x + attention_branch(x)
            x = x + mlp_branch(x)
        return x


class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping patches, and each patch is flattened
    and linearly projected to ``dim``. A learnable class token is prepended and
    learnable positional embeddings are added. The token sequence then goes
    through a pre-norm Transformer encoder. The pooled representation (class
    token, or mean over all tokens) is classified by a LayerNorm + Linear head.

    Args:
        image_size: input image size, an int or a ``(height, width)`` tuple.
        patch_size: patch size, an int or a ``(height, width)`` tuple; must
            divide ``image_size`` in both dimensions.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of Transformer blocks.
        heads: number of attention heads per block.
        mlp_ratio: feed-forward hidden width as a multiple of ``dim``
            (truncated with ``int``).
        pool: ``"cls"`` classifies the class token; ``"mean"`` classifies the
            mean over all tokens (class token included).
        channels: number of input image channels.
        qkv_bias: whether the attention QKV projections have a bias.
        dim_head: per-head width passed to the attention layers.
        attn_drop: dropout on attention weights.
        proj_drop: dropout after attention projections and in the MLPs.
        emb_drop: dropout applied to the embedded token sequence.
        init_factor: scale handed to :meth:`post_init`.
        weight_init: ``"pytorch"`` keeps PyTorch's default Linear init;
            ``"xavier"`` re-draws Linear weights and the embeddings from a
            fan-in-scaled normal (see :meth:`_post_init`).
    """

    def __init__(
        self,
        image_size=32,
        patch_size=4,
        num_classes=10,
        dim=768,
        depth=12,
        heads=12,
        mlp_ratio=4.0,
        pool="cls",
        channels=3,
        qkv_bias=False,
        dim_head=64,
        attn_drop=0.0,
        proj_drop=0.0,
        emb_drop=0.0,
        init_factor=1.0,
        weight_init="pytorch",
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert (
            image_height % patch_height == 0 and image_width % patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {
            "cls",
            "mean",
        }, "pool type must be either cls (cls token) or mean (mean pooling)"

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.Linear(patch_dim, dim),
        )
        self.dim = dim
        # One extra position for the prepended class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.init_factor = init_factor
        self.weight_init = weight_init
        self.dropout = nn.Dropout(emb_drop)

        self.transformer = Transformer(
            dim,
            depth,
            heads,
            dim_head,
            int(dim * mlp_ratio),
            qkv_bias,
            attn_drop,
            proj_drop,
        )
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

        self.post_init(self.init_factor, self.weight_init)

    def _post_init(self, module, factor, weight_init="pytorch"):
        """Recursively re-initialise ``module`` and all of its descendants.

        Rules, checked in this order:

        - A module defining its own ``post_init`` delegates to it.
        - ``nn.Linear``: only when ``weight_init == "xavier"``, weights are
          drawn from ``N(0, (factor / sqrt(fan_in))**2)`` and biases zeroed.
          Despite the name, the std uses fan-in only (no fan-out term). For
          any other ``weight_init`` the existing (PyTorch default) init is kept.
        - ``nn.LayerNorm``: weight set to ``factor`` and bias to 0, for every
          ``weight_init``.
        - Anything else: recurse into its direct children.

        Args:
            module: module whose parameters are re-initialised in place.
            factor: multiplier for the init std and the LayerNorm weight.
            weight_init: ``"pytorch"`` or ``"xavier"``.
        """
        if hasattr(module, "post_init"):
            module.post_init(factor, weight_init)
        elif isinstance(module, nn.Linear):
            if weight_init == "xavier":
                fan_in = module.weight.shape[1]
                nn.init.normal_(module.weight, mean=0.0, std=factor * (fan_in**-0.5))
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # Non-affine / bias-free LayerNorms have None parameters.
            if module.weight is not None:
                module.weight.data.fill_(factor * 1.0)
            if module.bias is not None:
                module.bias.data.fill_(0.0)
        else:
            for child in module.children():
                # Forward weight_init so nested Linear layers honour the
                # requested scheme instead of silently falling back to
                # "pytorch".
                self._post_init(child, factor, weight_init)

    @torch.no_grad()
    def post_init(self, factor, weight_init="pytorch"):
        """Re-initialise the model's parameters in place (without autograd).

        Args:
            factor: multiplier for the init std and LayerNorm weights.
            weight_init: ``"pytorch"`` or ``"xavier"``; with ``"xavier"`` the
                positional embedding and class token are also re-drawn from
                ``N(0, (factor / sqrt(dim))**2)``.
        """
        self._post_init(self.transformer, factor, weight_init)
        self._post_init(self.mlp_head, factor, weight_init)
        self._post_init(self.to_patch_embedding, factor, weight_init)
        if weight_init == "xavier":
            print("Enable xavier for Transformer")
            nn.init.normal_(self.pos_embedding, std=factor * (self.dim**-0.5))
            nn.init.normal_(self.cls_token, std=factor * (self.dim**-0.5))

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: image tensor laid out as ``(batch, channels, height, width)``,
                the layout required by the patch ``Rearrange`` pattern.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = self.to_patch_embedding(img)  # (batch, num_patches, dim)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, : (n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # "mean" averages over every token, the class token included.
        x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]

        return self.mlp_head(x)
